import os
import sys
import shutil
import tempfile


# Command line: <target> [debug]
target = sys.argv[1]

# An optional second argument must be the literal word 'debug'. In debug
# mode the input files are not overwritten at the end of the script.
if len(sys.argv) <= 2:
    debug = False
else:
    assert sys.argv[2] == 'debug'
    debug = True

# Compared further down against the distance between the READ1 and READ2
# alignments of a pair; pairs that lie further apart are skipped.
maxgap = 10000

# The target selects the score threshold; a target that belongs to
# neither group is rejected.
if target in (
    'mRNA', 'lncRNA', 'gencode', 'fantomcat', 'novel', 'genome',
):
    score_threshold = 0.9
elif target in (
    'chrM', 'rRNA', 'tRNA', 'snRNA', 'scRNA', 'histone', 'RPPH', 'snoRNA',
    'scaRNA', 'RMRP', 'yRNA', 'snar', 'vRNA', 'TERC', 'MALAT1', 'snhg',
):
    score_threshold = 0.8
else:
    raise Exception("Unknown target %s" % target)



# Collect the PSL chunk files of this target from the current directory,
# separately for READ1 and READ2. Names are expected to have four
# dot-separated parts: <target>.<READ1|READ2>.<range>.psl
filenames1 = []
filenames2 = []
for filename in os.listdir("."):
    terms = filename.split(".")
    if len(terms) != 4 or terms[3] != 'psl' or terms[0] != target:
        # Not a PSL chunk of this target.
        continue
    readno = terms[1]
    if readno not in ("READ1", "READ2"):
        raise Exception("Unexpected filename %s" % filename)
    if readno == "READ1":
        filenames1.append(filename)
    else:
        filenames2.append(filename)

def keyfunction(filename):
    """Return the (start, end) range encoded in a PSL chunk filename.

    The filename must look like <target>.<READ1|READ2>.<start>-<end>.psl,
    where <target> is the module-level target; the two bounds are returned
    as integers so that chunks sort numerically.
    """
    fields = filename.split(".")
    assert fields[0] == target
    assert fields[1] in ("READ1", "READ2")
    assert fields[3] == "psl"
    lower, upper = fields[2].split("-")
    return (int(lower), int(upper))

# Order the chunks of each read by the numeric range in their name.
filenames1.sort(key=keyfunction)
filenames2.sort(key=keyfunction)

# For each read, the chunks must tile the range without gaps or overlaps:
# the first chunk starts at 0 and every chunk starts where the previous
# one ended.
for expected_read, chunk_files in (("READ1", filenames1), ("READ2", filenames2)):
    current = 0
    for filename in chunk_files:
        terms = filename.split(".")
        assert terms[0] == target
        assert terms[1] == expected_read
        assert terms[3] == "psl"
        start, end = terms[2].split("-")
        start, end = int(start), int(end)
        assert start == current
        current = end

# READ1 and READ2 must have been split into the same number of chunks.
n1 = len(filenames1)
n2 = len(filenames2)
assert n1 == n2

# Check the log files left behind by the per-chunk alignment jobs. What a
# clean run looks like depends on the target group.
if target in ("mRNA", "lncRNA", "gencode", "fantomcat", "novel", "genome"):
    # Totals accumulated over all pslCheck reports.
    checked = 0
    failed = 0
    errors = 0
    # stderr may only contain progress messages starting with one of these.
    allowed_prefixes = (
        "[M::bwa_idx_load_from_disk]",
        "[M::process]",
        "[M::mem_process_seqs]",
        "[main]",
    )
    for i in range(n1):
        for readno in ("READ1", "READ2"):
            filename = "script_%s_%s_%d.stderr" % (target, readno, i)
            print("Checking error file %s" % filename)
            with open(filename) as handle:
                for line in handle:
                    if not line.startswith(allowed_prefixes):
                        raise Exception(line)
            # stdout must be empty.
            filename = "script_%s_%s_%d.stdout" % (target, readno, i)
            print("Checking output file %s" % filename)
            with open(filename) as handle:
                assert not handle.read()
            # The pslCheck report is either the "stdin is empty" notice or
            # a single "checked: N failed: N errors: N" summary.
            filename = "%s.%s_%d.out" % (target, readno, i)
            print("Checking pslCheck output file %s" % filename)
            with open(filename) as handle:
                report = handle.read()
            if report != "stdin is empty\n":
                words = report.split()
                assert len(words) == 6
                assert words[0] == "checked:"
                assert words[2] == "failed:"
                assert words[4] == "errors:"
                checked += int(words[1])
                failed += int(words[3])
                errors += int(words[5])
    assert failed == 0
    assert errors == 0
    print("Total checked: %d failed: %d errors: %d" % (checked, failed, errors))
else:
    for i in range(n1):
        for readno in ("READ1", "READ2"):
            # stderr must be empty.
            filename = "script_%s_%s_%d.stderr" % (target, readno, i)
            print("Checking error file %s" % filename)
            with open(filename) as handle:
                assert not handle.read()
            # The last line of stdout must be 'Done'.
            filename = "script_%s_%s_%d.stdout" % (target, readno, i)
            print("Checking output file %s" % filename)
            # Reset for every file: otherwise an empty stdout file would be
            # judged by the last line of the previously checked file.
            last_line = ""
            with open(filename) as handle:
                for last_line in handle:
                    pass
            assert last_line.strip() == 'Done'

# Alignments of READ1 and READ2, stored as
# {query name: {target name: [PSL line, ...]}}.
alignments1 = {}
alignments2 = {}

# Strands accepted in the input: both for the genome, forward only for
# every other target.
if target == "genome":
    strands = "+-"
else:
    strands = "+"

for filenames, alignments in zip([filenames1, filenames2],
                                 [alignments1, alignments2]):
    for filename in filenames:
        print("Reading", filename)
        with open(filename) as handle:
            for alignment in handle:
                terms = alignment.split()
                assert len(terms) == 21
                strand = terms[8]
                assert strand in strands
                qName = terms[9]
                tName = terms[13]
                # Group by query, then by target. The target list has to be
                # looked up inside the query's own dictionary so that every
                # alignment of a query on a target is kept, not only the
                # last one read.
                alignments.setdefault(qName, {}).setdefault(tName, []).append(alignment)

# Multiplied by two and by score_threshold further down to obtain the
# minimum combined score of a read pair.
# NOTE(review): presumably the length of each read - confirm.
query_length = 33

# A library is any <library>.index.txt file in the current directory whose
# library name contains no dot. Libraries are processed in sorted order.
libraries = []
for entry in os.listdir("."):
    parts = entry.split(".")
    if len(parts) == 3 and parts[1] == "index" and parts[2] == "txt":
        libraries.append(parts[0])

libraries.sort()

# Header written at the top of every PSL file this script produces; the
# same text is compared against the first five lines when the file is read
# back. The literal mixes tabs and padding spaces, so it must be kept
# byte-for-byte.
header = """\
psLayout version 3

match	mis- 	rep. 	N's	Q gap	Q gap	T gap	T gap	strand	Q        	Q   	Q    	Q  	T        	T   	T    	T  	block	blockSizes 	qStarts	 tStarts
     	match	match	   	count	bases	count	bases	      	name     	size	start	end	name     	size	start	end	count
---------------------------------------------------------------------------------------------------------------------------------------------------------------
"""

# READ1 / READ2 sequence identifiers of the read pairs that remain unmapped
# after this target, collected over all libraries.
readnos1 = []
readnos2 = []
# [temporary file, file to replace] pairs; the replacement itself happens
# at the very end of the script.
old_new_filenames = []

# Columns of a 21-column PSL record that must hold integers (matches
# through tBaseInsert, the query and target size/start/end, blockCount).
_PSL_INTEGER_COLUMNS = (0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 14, 15, 16, 17)


def _alignment_score(words):
    """Return the score of one PSL record given as its list of fields.

    The score is the number of matches minus the mismatches, the bases
    inserted in the query and in the target, and the query bases left
    outside the aligned region (qStart + qSize - qEnd).
    """
    matches = int(words[0])
    misMatches = int(words[1])
    qBaseInsert = int(words[5])
    tBaseInsert = int(words[7])
    qSize = int(words[10])
    qStart = int(words[11])
    qEnd = int(words[12])
    clipped = qStart + qSize - qEnd
    return matches - misMatches - qBaseInsert - tBaseInsert - clipped


def _select_pairs(hits1, hits2):
    """Select the best READ1/READ2 alignment pairs of one read pair.

    hits1 and hits2 map a target name to the PSL lines of READ1 and READ2
    on that target. Only pairs on a shared target and on the same strand
    are considered, with READ2 lying 0 to maxgap bases beyond READ1 in the
    direction of the strand. Among those, the pairs with the highest
    combined score above score_threshold * 2 * query_length are kept; ties
    are broken by the shortest extent on the target and, except for the
    genome, by the smallest target size.

    Returns (score, pairs), where pairs is a list of
    (READ1 line, READ2 line) tuples that all tie for the best rank, and
    score is their combined score (the bare threshold if pairs is empty).
    """
    score = score_threshold * 2 * query_length
    shortest_extent = None
    shortest_tSize = None
    pairs = []
    # Sorted so that tied pairs are written in a reproducible order; the
    # iteration order of a set of strings differs between runs.
    for tName in sorted(set(hits1).intersection(hits2)):
        for alignment1 in hits1[tName]:
            words1 = alignment1.split()
            assert len(words1) == 21
            score1 = _alignment_score(words1)
            strand1 = words1[8]
            tSize1 = int(words1[14])
            tStart1 = int(words1[15])
            tEnd1 = int(words1[16])
            for alignment2 in hits2[tName]:
                words2 = alignment2.split()
                assert len(words2) == 21
                score2 = _alignment_score(words2)
                strand2 = words2[8]
                tSize2 = int(words2[14])
                tStart2 = int(words2[15])
                tEnd2 = int(words2[16])
                if strand1 == "+":
                    if strand2 != "+":
                        assert strand2 == "-"
                        continue
                    gap = tStart2 - tEnd1
                    extent = tEnd2 - tStart1
                else:
                    assert strand1 == "-"
                    if strand2 != "-":
                        assert strand2 == "+"
                        continue
                    gap = tStart1 - tEnd2
                    extent = tEnd1 - tStart2
                assert tSize1 == tSize2
                if gap < 0 or gap > maxgap:
                    continue
                total = score1 + score2
                if total < score:
                    continue
                if total > score:
                    # New best score: forget everything found so far.
                    score = total
                    shortest_extent = extent
                    shortest_tSize = tSize1
                    pairs.clear()
                elif not pairs:
                    # The total equals the bare threshold, which a pair has
                    # to exceed strictly; nothing to tie with yet.
                    continue
                elif extent > shortest_extent:
                    continue
                elif extent < shortest_extent:
                    shortest_extent = extent
                    shortest_tSize = tSize1
                    pairs.clear()
                elif target != 'genome':
                    # Same score and extent: prefer the smaller target.
                    if tSize1 > shortest_tSize:
                        continue
                    if tSize1 < shortest_tSize:
                        shortest_tSize = tSize1
                        pairs.clear()
                pairs.append((alignment1, alignment2))
    return score, pairs


def _paired_records(qName, qName1, qName2, alignment1, alignment2, score):
    """Return the two output PSL lines for one selected alignment pair.

    The pair is re-validated (field count, integer columns, query names,
    shared target name and size, combined score). READ2 is then put on
    the strand opposite to READ1 with its query start/end mirrored within
    the query size, and both records are renamed to qName. The lines are
    returned tab-separated and newline-terminated.

    Raises ValueError if READ1 of a genome alignment has a strand other
    than '+' or '-'.
    """
    words1 = alignment1.split()
    words2 = alignment2.split()
    for words in (words1, words2):
        assert len(words) == 21
        for column in _PSL_INTEGER_COLUMNS:
            int(words[column])  # raises ValueError on a non-numeric field
    assert qName1 == words1[9]
    assert qName2 == words2[9]
    assert words1[13] == words2[13]  # same target name
    assert int(words1[14]) == int(words2[14])  # same target size
    assert _alignment_score(words1) + _alignment_score(words2) == score
    assert score > score_threshold * 2 * query_length
    strand1 = words1[8]
    strand2 = words2[8]
    if target == 'genome':
        if strand1 == "+":
            assert strand2 == "+"
            strand2 = "-"
        elif strand1 == "-":
            assert strand2 == "-"
            strand2 = "+"
        else:
            raise ValueError("Unexpected value '%s' for strand1" % strand1)
    else:
        assert strand1 == "+"
        assert strand2 == "+"
        strand2 = "-"
    words2[8] = strand2
    qSize2 = int(words2[10])
    qStart2 = int(words2[11])
    qEnd2 = int(words2[12])
    words2[11] = str(qSize2 - qEnd2)
    words2[12] = str(qSize2 - qStart2)
    words1[9] = qName
    words2[9] = qName
    return "\t".join(words1) + "\n", "\t".join(words2) + "\n"


for library in libraries:
    index_filename = "%s.index.txt" % library
    psl_filename = "%s.%s.psl" % (library, target)

    # Step 1: write the paired alignments of this library. Each index line
    # holds the name of a read pair and the comma-separated identifiers of
    # its READ1 and READ2 sequences.
    print("Reading %s" % index_filename)
    with open(index_filename) as handle:
        print("Writing %s" % psl_filename)
        with open(psl_filename, 'w') as output:
            output.write(header)
            for line in handle:
                qName, qName1_qName2 = line.split()
                qName1, qName2 = qName1_qName2.split(",")
                if qName1 not in alignments1:
                    continue
                if qName2 not in alignments2:
                    continue
                score, pairs = _select_pairs(alignments1[qName1],
                                             alignments2[qName2])
                for alignment1, alignment2 in pairs:
                    line1, line2 = _paired_records(qName, qName1, qName2,
                                                   alignment1, alignment2,
                                                   score)
                    output.write(line1)
                    output.write(line2)

    # Step 2: read the file back to verify it and to find out which read
    # pairs were mapped.
    print("Reading %s" % psl_filename)
    qNames = set()
    with open(psl_filename) as handle:
        assert "".join(handle.readline() for _ in range(5)) == header
        for line in handle:
            words = line.split()
            assert len(words) == 21
            qNames.add(words[9])
    print("Number of mapped sequences: %d" % len(qNames))

    # Step 3: write the index entries that are still unmapped to a
    # temporary file, and remember the sequences they refer to.
    print("Reading %s" % index_filename)
    counter = 0
    with open(index_filename) as handle:
        with tempfile.NamedTemporaryFile(delete=False, mode='wt') as output:
            for line in handle:
                qName, readno1_readno2 = line.split()
                if qName in qNames:
                    continue
                output.write(line)
                counter += 1
                readno1, readno2 = readno1_readno2.split(",")
                readnos1.append(readno1)
                readnos2.append(readno2)
    print("Number of remaining sequences: %d" % counter)
    old_new_filenames.append([output.name, index_filename])


# Sequences still needed by at least one unmapped read pair.
readnos1 = set(readnos1)
readnos2 = set(readnos2)
print("Number of remaining unique sequences: READ1: %d, READ2: %d" % (len(readnos1), len(readnos2)))

# Copy the still-needed records of each sequence list to a temporary file.
# Every record is read as two lines: a '>' name line followed by one
# sequence line.
for filename, remaining in (("seqlist_READ1.fa", readnos1),
                            ("seqlist_READ2.fa", readnos2)):
    print("Reading", filename)
    with open(filename) as handle:
        with tempfile.NamedTemporaryFile(delete=False, mode='wt') as output:
            for line1 in handle:
                assert line1[0] == ">"
                line2 = next(handle, None)  # sequence
                if line2 is None:
                    # Report a truncated file explicitly instead of with a
                    # bare StopIteration.
                    raise Exception("Missing sequence line after %s in %s"
                                    % (line1.strip(), filename))
                if line1[1:].strip() in remaining:
                    output.write(line1)
                    output.write(line2)
    old_new_filenames.append([output.name, filename])

# Replacing the original files is deliberately the very last step, so that
# a crash earlier on cannot lose everything. Debug mode leaves the
# originals untouched.
if not debug:
    for replacement, destination in old_new_filenames:
        print("Overwriting %s" % destination)
        shutil.move(replacement, destination)
